{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.utils import shuffle\n",
    "import re\n",
    "import time\n",
    "import collections\n",
    "import os\n",
    "import itertools\n",
    "from sklearn.cross_validation import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_start_end(string):\n",
    "    string = string.split()\n",
    "    strings = []\n",
    "    for s in string:\n",
    "        s = list(s)\n",
    "        s[0] = '<%s'%(s[0])\n",
    "        s[-1] = '%s>'%(s[-1])\n",
    "        strings.extend(s)\n",
    "    return strings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000 10000\n"
     ]
    }
   ],
   "source": [
    "with open('lemmatization-en.txt','r') as fopen:\n",
    "    texts = fopen.read().split('\\n')\n",
    "after, before = [], []\n",
    "for i in texts[:10000]:\n",
    "    splitted = i.encode('ascii', 'ignore').decode(\"utf-8\").lower().split('\\t')\n",
    "    if len(splitted) < 2:\n",
    "        continue\n",
    "    \n",
    "    after.append(add_start_end(splitted[0]))\n",
    "    before.append(add_start_end(splitted[1]))\n",
    "    \n",
    "print(len(after),len(before))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words, atleast=1):\n",
    "    count = [['GO', 0], ['PAD', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    counter = collections.Counter(words).most_common(n_words)\n",
    "    counter = [i for i in counter if i[1] >= atleast]\n",
    "    count.extend(counter)\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word, 0)\n",
    "        if index == 0:\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[0][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 65\n",
      "Most common words [('e', 9763), ('i', 7229), ('n', 6177), ('s>', 6095), ('o', 5769), ('a', 5633)]\n",
      "Sample data [46, 5, 11, 14, 35, 47, 4, 6, 10, 39] ['<f', 'i', 'r', 's', 't>', '<t', 'e', 'n', 't', 'h>']\n",
      "filtered vocab size: 69\n",
      "% of vocab used: 106.15%\n"
     ]
    }
   ],
   "source": [
    "concat_from = list(itertools.chain(*before))\n",
    "vocabulary_size_from = len(list(set(concat_from)))\n",
    "data_from, count_from, dictionary_from, rev_dictionary_from = build_dataset(concat_from, vocabulary_size_from)\n",
    "print('vocab from size: %d'%(vocabulary_size_from))\n",
    "print('Most common words', count_from[4:10])\n",
    "print('Sample data', data_from[:10], [rev_dictionary_from[i] for i in data_from[:10]])\n",
    "print('filtered vocab size:',len(dictionary_from))\n",
    "print(\"% of vocab used: {}%\".format(round(len(dictionary_from)/vocabulary_size_from,4)*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<1>', 'EOS']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for i in range(len(after)):\n",
    "    after[i].append('EOS')\n",
    "after[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X, test_X, train_Y, test_Y = train_test_split(before, after, test_size = 0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "GO = dictionary_from['GO']\n",
    "PAD = dictionary_from['PAD']\n",
    "EOS = dictionary_from['EOS']\n",
    "UNK = dictionary_from['UNK']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sent2idx(sent, vocab, UNK=UNK):\n",
    "    tokens = sent\n",
    "    oovs = []\n",
    "    extend_tokens = []\n",
    "    tokenized = []\n",
    "    for token in tokens:\n",
    "        if token not in vocab:\n",
    "            tokenized.append(UNK)\n",
    "            if token not in oovs:\n",
    "                oovs.append(token)\n",
    "            extend_tokens.append(len(vocab) + oovs.index(token))\n",
    "        else:\n",
    "            extend_tokens.append(vocab[token])\n",
    "            tokenized.append(vocab[token])\n",
    "    return tokenized, extend_tokens, oovs\n",
    "\n",
    "def target2idx(sent, oovs, vocab,UNK=UNK):\n",
    "    tokens = sent\n",
    "    tokenized = []\n",
    "    for token in tokens:\n",
    "        if token not in vocab:\n",
    "            if token not in oovs:\n",
    "                tokenized.append(UNK)\n",
    "            else:\n",
    "                tokenized.append(len(vocab) + oovs.index(token))\n",
    "        else:\n",
    "            tokenized.append(vocab[token])\n",
    "    return tokenized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import collections\n",
    "from tensorflow.python.util import nest\n",
    "from tensorflow.contrib.framework.python.framework import tensor_util\n",
    "from tensorflow.python.layers.core import Dense\n",
    "\n",
    "class CopyNetWrapperState(\n",
    "    collections.namedtuple(\"CopyNetWrapperState\", (\"cell_state\", \"last_ids\", \"prob_c\"))):\n",
    "    def clone(self, **kwargs):\n",
    "        def with_same_shape(old, new):\n",
    "            \"\"\"Check and set new tensor's shape.\"\"\"\n",
    "            if isinstance(old, tf.Tensor) and isinstance(new, tf.Tensor):\n",
    "                return tensor_util.with_same_shape(old, new)\n",
    "            return new\n",
    "\n",
    "        return nest.map_structure(\n",
    "            with_same_shape,\n",
    "            self,\n",
    "            super(CopyNetWrapperState, self)._replace(**kwargs))\n",
    "\n",
    "class CopyNetWrapper(tf.contrib.rnn.RNNCell):\n",
    "    '''\n",
    "    A copynet RNN cell wrapper\n",
    "    '''\n",
    "    def __init__(self, cell, source_extend_tokens, max_oovs, encode_output, output_layer, vocab_size, name=None):\n",
    "        '''\n",
    "        Args:\n",
    "            - cell: the decoder cell\n",
    "            - source_extend_tokens: input tokens with oov word ids\n",
    "            - max_oovs: max number of oov words in each batch\n",
    "            - encode_output: the output of encoder cell\n",
    "            - output_layer: the layer used to map decoder output to vocab distribution\n",
    "            - vocab_size: this is target vocab size\n",
    "        '''\n",
    "        super(CopyNetWrapper, self).__init__(name=name)\n",
    "        self.cell = cell\n",
    "        self.source_extend_tokens = source_extend_tokens\n",
    "        self.encode_output = encode_output\n",
    "        self.max_oovs = max_oovs\n",
    "        self.output_layer = output_layer\n",
    "        self._output_size = vocab_size + max_oovs\n",
    "        self.copy_layer = Dense(self.cell.output_size,activation=tf.tanh,use_bias=False,name=\"Copy_Weigth\")\n",
    "\n",
    "    def __call__(self, inputs, state):\n",
    "        last_ids = state.last_ids\n",
    "        prob_c = state.prob_c\n",
    "        cell_state = state.cell_state\n",
    "        \n",
    "        '''\n",
    "            - Selective read\n",
    "                At first, my implement is based on the paper, which looks like belowing:\n",
    "                    mask = tf.cast(tf.equal(tf.expand_dims(last_ids,1),self.source_extend_tokens), tf.float32)\n",
    "                    pt = mask * prob_c\n",
    "                    pt_sum = tf.reduce_sum(pt, axis=1)\n",
    "                    pt = tf.where(tf.less(pt_sum, 1e-7), pt, pt / tf.expand_dims(pt_sum, axis=1))\n",
    "                    selective_read = tf.einsum(\"ijk,ij->ik\",self.encode_output, pt)\n",
    "                It looks OK for me, but I got NAN loss after one training step. I tried tf.Print to debug, but cannot find \n",
    "                the problem. Then I tried tfdbg to detect the source of NAN and found it refer to the tf.where code piece, \n",
    "                which is 'pt / tf.expand_dims(pt_sum, axis=1)'. \n",
    "                I am not sure why this code will got NANs. Maybe pt and pt_sum are not of the same shape, due to the broadcasting,\n",
    "                we will got NANs, and these NANs will be used to calculate the gradients, and cause NANs values.\n",
    "                Take a close look at the paper, we will find that p(y_t,c|.) is same for all the same input tokens. And selective_read\n",
    "                is sum of p_t * h_t, in which p_t is 1/K *(p(x_t,c|.)), K is sum of p(x_t,c|.). So p_t is equal to:\n",
    "                p_t = 1 / (number of times x_t shows in input tokens)\n",
    "        '''\n",
    "        # get selective read\n",
    "        # batch * input length\n",
    "        mask = tf.cast(tf.equal(tf.expand_dims(last_ids,1),self.source_extend_tokens), tf.float32)\n",
    "        mask_sum = tf.reduce_sum(mask, axis=1)\n",
    "        mask = tf.where(tf.less(mask_sum, 1e-7), mask, mask / tf.expand_dims(mask_sum, axis=1))\n",
    "        pt = mask * prob_c\n",
    "        selective_read = tf.einsum(\"ijk,ij->ik\",self.encode_output, pt)\n",
    "\n",
    "        inputs = tf.concat([inputs, selective_read], axis=-1)\n",
    "        outputs, cell_state = self.cell(inputs, cell_state)\n",
    "\n",
    "        # this is generate mode\n",
    "        vocab_dist = self.output_layer(outputs)\n",
    "\n",
    "        # this is copy mode\n",
    "        # batch * length * output size\n",
    "        copy_score = self.copy_layer(self.encode_output)\n",
    "        # batch * length\n",
    "        copy_score = tf.einsum(\"ijk,ik->ij\",copy_score,outputs)\n",
    "\n",
    "        vocab_size = tf.shape(vocab_dist)[-1]\n",
    "        extended_vsize = vocab_size + self.max_oovs \n",
    "\n",
    "        batch_size = tf.shape(vocab_dist)[0]\n",
    "        extra_zeros = tf.zeros((batch_size, self.max_oovs))\n",
    "        # batch * extend vocab size\n",
    "        vocab_dists_extended = tf.concat(axis=-1, values=[vocab_dist, extra_zeros])\n",
    "\n",
    "        # this part is same as that of point generator, but using einsum is much simpler.\n",
    "        source_mask = tf.one_hot(self.source_extend_tokens,extended_vsize)\n",
    "        attn_dists_projected = tf.einsum(\"ijn,ij->in\", source_mask, copy_score)\n",
    "\n",
    "        final_dist = vocab_dists_extended + attn_dists_projected\n",
    "\n",
    "        # This is greeding search, need to test with beam search\n",
    "        last_ids = tf.argmax(final_dist, axis=-1, output_type=tf.int32)\n",
    "\n",
    "        # this is used to calculate p(y_t,c|.)\n",
    "        # safe softmax\n",
    "        final_dist_max = tf.expand_dims(tf.reduce_max(final_dist,axis=1), axis=1)\n",
    "        final_dist_exp = tf.reduce_sum(tf.exp(final_dist - final_dist_max),axis=1)\n",
    "        p_c = tf.exp(attn_dists_projected - final_dist_max) / tf.expand_dims(final_dist_exp, axis=1)\n",
    "        p_c = tf.einsum(\"ijn,in->ij\",source_mask,p_c)\n",
    "\n",
    "        state = CopyNetWrapperState(cell_state=cell_state, last_ids=last_ids, prob_c=p_c)\n",
    "        return final_dist, state\n",
    "\n",
    "    @property\n",
    "    def state_size(self):\n",
    "        \"\"\"size(s) of state(s) used by this cell.\n",
    "            It can be represented by an Integer, a TensorShape or a tuple of Integers\n",
    "            or TensorShapes.\n",
    "        \"\"\"\n",
    "        return CopyNetWrapperState(cell_state=self.cell.state_size, last_ids=tf.TensorShape([]),\n",
    "            prob_c = self.source_extend_tokens.shape[-1].value)\n",
    "\n",
    "    @property\n",
    "    def output_size(self):\n",
    "        \"\"\"Integer or TensorShape: size of outputs produced by this cell.\"\"\"\n",
    "        return self._output_size\n",
    "\n",
    "    def zero_state(self, batch_size, dtype):\n",
    "        with tf.name_scope(type(self).__name__ + \"ZeroState\", values=[batch_size]):\n",
    "            cell_state = self.cell.zero_state(batch_size, dtype)\n",
    "            last_ids = tf.zeros([batch_size], tf.int32) - 1\n",
    "            prob_c = tf.zeros([batch_size, tf.shape(self.encode_output)[1]], tf.float32)\n",
    "            return CopyNetWrapperState(cell_state=cell_state, last_ids=last_ids, prob_c=prob_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Stemmer:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, \n",
    "                 from_dict_size, learning_rate, max_oovs = 50):\n",
    "        \n",
    "        def lstm_cell(size = size_layer, reuse=False):\n",
    "            return tf.nn.rnn_cell.GRUCell(size, reuse=reuse)\n",
    "        \n",
    "        def attention(encoder_out, seq_len, reuse=False):\n",
    "            attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units = size_layer, \n",
    "                                                                    memory = encoder_out,\n",
    "                                                                    memory_sequence_length = seq_len)\n",
    "            return tf.contrib.seq2seq.AttentionWrapper(\n",
    "            cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell(reuse=reuse) for _ in range(num_layers)]), \n",
    "                attention_mechanism = attention_mechanism,\n",
    "                attention_layer_size = size_layer)\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype=tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        self.source_oov_words = tf.placeholder(tf.int32, shape=[])\n",
    "        self.source_extend_tokens = tf.placeholder(tf.int32, shape=[None, None])\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        \n",
    "        condition = tf.less(decoder_input, from_dict_size)\n",
    "        self.decoder_input = self.Y\n",
    "        self.decoder_input_length = self.Y_seq_len\n",
    "        self.predict_count = tf.reduce_sum(self.decoder_input_length)\n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([from_dict_size, embedded_size], -1, 1))\n",
    "        decoder_embeddings = tf.Variable(tf.random_uniform([from_dict_size, embedded_size], -1, 1))\n",
    "        encoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "        \n",
    "        for n in range(num_layers):\n",
    "            (out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(\n",
    "                cell_fw = lstm_cell(size_layer // 2),\n",
    "                cell_bw = lstm_cell(size_layer // 2),\n",
    "                inputs = encoder_embedded,\n",
    "                sequence_length = self.X_seq_len,\n",
    "                dtype = tf.float32,\n",
    "                scope = 'bidirectional_rnn_%d'%(n))\n",
    "            encoder_embedded = tf.concat((out_fw, out_bw), 2)\n",
    "        \n",
    "        bi_state = tf.concat((state_fw, state_bw), -1)\n",
    "        self.initial_state = tuple(bi_state for _ in range(num_layers))\n",
    "        self.encoder_out = encoder_embedded\n",
    "        decoder_cell = attention(self.encoder_out, self.X_seq_len)\n",
    "        dense_layer = tf.layers.Dense(from_dict_size)\n",
    "        \n",
    "        initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=self.initial_state)\n",
    "        \n",
    "        self.decode_cell = CopyNetWrapper(decoder_cell, self.source_extend_tokens, max_oovs, \n",
    "                                          self.encoder_out, dense_layer, from_dict_size)\n",
    "        \n",
    "        dense_layer = None\n",
    "        \n",
    "        self.initial_state = self.decode_cell.zero_state(batch_size,\n",
    "                                                         tf.float32).clone(cell_state=initial_state)\n",
    "        \n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(decoder_embeddings, decoder_input),\n",
    "                sequence_length = self.decoder_input_length,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = self.decode_cell,\n",
    "                helper = training_helper,\n",
    "                initial_state = self.initial_state,\n",
    "                output_layer = dense_layer)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        \n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                embedding = encoder_embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS)\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = self.decode_cell,\n",
    "                helper = predicting_helper,\n",
    "                initial_state = self.initial_state,\n",
    "                output_layer = dense_layer)\n",
    "        \n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.X_seq_len))\n",
    "        \n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "        self.predicting_ids = predicting_decoder_output.sample_id\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        \n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 128\n",
    "num_layers = 2\n",
    "embedded_size = 64\n",
    "learning_rate = 1e-3\n",
    "batch_size = 32\n",
    "epoch = 15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Stemmer(size_layer, num_layers, embedded_size, len(dictionary_from), \n",
    "                learning_rate)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.preprocessing.sequence import pad_sequences\n",
    "\n",
    "def batching(X, Y):\n",
    "    s_, es_, oovs_, target_ = [], [], [], []\n",
    "    for x, y in zip(X, Y):\n",
    "        s,es,oovs = sent2idx(x, dictionary_from)\n",
    "        target = target2idx(y, oovs,dictionary_from)\n",
    "        s_.append(s)\n",
    "        es_.append(es)\n",
    "        oovs_.append(oovs)\n",
    "        target_.append(target)\n",
    "    s_ = pad_sequences(s_,padding='post')\n",
    "    es_ = pad_sequences(es_,padding='post')\n",
    "    target_ = pad_sequences(target_,padding='post')\n",
    "    maxlen = max([len(o) for o in oovs_])\n",
    "    return s_, es_, target_, maxlen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 14.26it/s, accuracy=0.889, cost=0.405]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:02<00:00, 31.03it/s, accuracy=0.873, cost=0.423]\n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:18, 13.76it/s, accuracy=0.9, cost=0.401]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, avg loss: 1.276579, avg accuracy: 0.663112\n",
      "epoch: 0, avg loss test: 0.442437, avg accuracy test: 0.898187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 14.87it/s, accuracy=0.953, cost=0.186] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.09it/s, accuracy=0.869, cost=0.447]\n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:16, 14.95it/s, accuracy=0.971, cost=0.167]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, avg loss: 0.281214, avg accuracy: 0.929588\n",
      "epoch: 1, avg loss test: 0.214072, avg accuracy test: 0.955397\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 14.85it/s, accuracy=0.967, cost=0.182] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 33.78it/s, accuracy=0.958, cost=0.138] \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:17, 14.08it/s, accuracy=0.967, cost=0.172]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, avg loss: 0.141903, avg accuracy: 0.965294\n",
      "epoch: 2, avg loss test: 0.151230, avg accuracy test: 0.972862\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 14.82it/s, accuracy=0.977, cost=0.0747]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.26it/s, accuracy=0.966, cost=0.103] \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:19, 12.59it/s, accuracy=0.984, cost=0.0479]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, avg loss: 0.090413, avg accuracy: 0.977463\n",
      "epoch: 3, avg loss test: 0.108176, avg accuracy test: 0.982170\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 14.46it/s, accuracy=0.988, cost=0.0638]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.25it/s, accuracy=0.961, cost=0.43]   \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:20, 12.17it/s, accuracy=0.993, cost=0.0526]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, avg loss: 0.064194, avg accuracy: 0.983682\n",
      "epoch: 4, avg loss test: 0.080756, avg accuracy test: 0.990902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 14.70it/s, accuracy=0.989, cost=0.041]  \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.11it/s, accuracy=0.985, cost=0.0534] \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:18, 13.67it/s, accuracy=0.966, cost=0.156]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, avg loss: 0.046579, avg accuracy: 0.987737\n",
      "epoch: 5, avg loss test: 0.065392, avg accuracy test: 0.993854\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 14.60it/s, accuracy=0.969, cost=0.143]  \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.29it/s, accuracy=0.986, cost=0.144] \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:17, 14.56it/s, accuracy=0.991, cost=0.021] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, avg loss: 0.036858, avg accuracy: 0.990067\n",
      "epoch: 6, avg loss test: 0.060458, avg accuracy test: 0.995457\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 15.14it/s, accuracy=1, cost=0.0133]     \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.08it/s, accuracy=0.978, cost=0.109] \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:15, 15.79it/s, accuracy=0.992, cost=0.0229]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 7, avg loss: 0.042807, avg accuracy: 0.988341\n",
      "epoch: 7, avg loss test: 0.057284, avg accuracy test: 0.994005\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 13.93it/s, accuracy=0.989, cost=0.0325] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.26it/s, accuracy=1, cost=0.00304]   \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:16, 14.91it/s, accuracy=1, cost=0.0106] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, avg loss: 0.029407, avg accuracy: 0.991957\n",
      "epoch: 8, avg loss test: 0.059953, avg accuracy test: 0.995456\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 14.95it/s, accuracy=0.958, cost=0.131]  \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 33.20it/s, accuracy=1, cost=0.00232]    \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:15, 15.77it/s, accuracy=1, cost=0.00544]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9, avg loss: 0.024976, avg accuracy: 0.993330\n",
      "epoch: 9, avg loss test: 0.063851, avg accuracy test: 0.995135\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 15.14it/s, accuracy=0.986, cost=0.0415] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.23it/s, accuracy=0.992, cost=0.0098] \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:15, 15.63it/s, accuracy=0.997, cost=0.00541]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 10, avg loss: 0.026247, avg accuracy: 0.992622\n",
      "epoch: 10, avg loss test: 0.052638, avg accuracy test: 0.995622\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:17<00:00, 15.54it/s, accuracy=0.99, cost=0.0363]  \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 33.74it/s, accuracy=0.975, cost=0.242]  \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:18, 13.28it/s, accuracy=0.975, cost=0.0897]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 11, avg loss: 0.054017, avg accuracy: 0.986113\n",
      "epoch: 11, avg loss test: 0.066696, avg accuracy test: 0.994378\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:16<00:00, 15.22it/s, accuracy=0.996, cost=0.00706]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 33.86it/s, accuracy=1, cost=0.00807]    \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:17, 14.33it/s, accuracy=0.997, cost=0.0161]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 12, avg loss: 0.020465, avg accuracy: 0.994062\n",
      "epoch: 12, avg loss test: 0.059716, avg accuracy test: 0.996420\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:16<00:00, 14.59it/s, accuracy=0.996, cost=0.00621]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.67it/s, accuracy=1, cost=0.00143]    \n",
      "train minibatch loop:   1%|          | 2/250 [00:00<00:15, 15.71it/s, accuracy=0.996, cost=0.0193]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 13, avg loss: 0.020216, avg accuracy: 0.994379\n",
      "epoch: 13, avg loss test: 0.055051, avg accuracy test: 0.996833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:16<00:00, 14.97it/s, accuracy=1, cost=0.00272]    \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 34.40it/s, accuracy=1, cost=0.00429]    "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 14, avg loss: 0.018538, avg accuracy: 0.995176\n",
      "epoch: 14, avg loss test: 0.048555, avg accuracy test: 0.997163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "from sklearn.utils import shuffle\n",
    "\n",
    "# Each epoch: reshuffle both splits, run one optimisation pass over the training\n",
    "# split, then score the test split without running the optimizer.\n",
    "# NOTE: batch_x, batch_es, batch_y and maxlen keep the values of the last test\n",
    "# minibatch after this cell finishes; the prediction cell below reads them.\n",
    "for EPOCH in range(epoch):\n",
    "    total_loss, total_accuracy, total_loss_test, total_accuracy_test = 0, 0, 0, 0\n",
    "    train_X, train_Y = shuffle(train_X, train_Y)\n",
    "    test_X, test_Y = shuffle(test_X, test_Y)\n",
    "\n",
    "    # Start offsets of every training minibatch; the last one may be partial.\n",
    "    train_starts = range(0, len(train_X), batch_size)\n",
    "    pbar = tqdm(train_starts, desc='train minibatch loop')\n",
    "    for k in pbar:\n",
    "        index = min(k+batch_size,len(train_X))\n",
    "        batch_x, batch_es, batch_y, maxlen = batching(train_X[k: index],\n",
    "                                                     train_Y[k: index])\n",
    "        # Fetching model.optimizer is what applies the gradient step.\n",
    "        acc, loss, _ = sess.run([model.accuracy, model.cost, model.optimizer],\n",
    "                                feed_dict={model.X:batch_x,\n",
    "                                           model.source_extend_tokens:batch_es,\n",
    "                                           model.Y:batch_y,\n",
    "                                           model.source_oov_words:maxlen})\n",
    "        total_loss += loss\n",
    "        total_accuracy += acc\n",
    "        pbar.set_postfix(cost=loss, accuracy = acc)\n",
    "\n",
    "    # Start offsets of every test minibatch; the last one may be partial.\n",
    "    test_starts = range(0, len(test_X), batch_size)\n",
    "    pbar = tqdm(test_starts, desc='test minibatch loop')\n",
    "    for k in pbar:\n",
    "        index = min(k+batch_size,len(test_X))\n",
    "        batch_x, batch_es, batch_y, maxlen = batching(test_X[k: index],\n",
    "                                                     test_Y[k: index])\n",
    "        # Evaluation only: the optimizer is not fetched here.\n",
    "        acc, loss = sess.run([model.accuracy, model.cost],\n",
    "                             feed_dict={model.X:batch_x,\n",
    "                                        model.source_extend_tokens:batch_es,\n",
    "                                        model.Y:batch_y,\n",
    "                                        model.source_oov_words:maxlen})\n",
    "        total_loss_test += loss\n",
    "        total_accuracy_test += acc\n",
    "        pbar.set_postfix(cost=loss, accuracy = acc)\n",
    "\n",
    "    # Average over the number of minibatches that actually ran. Dividing by\n",
    "    # len(X) / batch_size under-counts whenever the last minibatch is partial,\n",
    "    # which inflates the reported averages (accuracy could even exceed 1.0).\n",
    "    total_loss /= len(train_starts)\n",
    "    total_accuracy /= len(train_starts)\n",
    "    total_loss_test /= len(test_starts)\n",
    "    total_accuracy_test /= len(test_starts)\n",
    "\n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(EPOCH, total_loss, total_accuracy))\n",
    "    print('epoch: %d, avg loss test: %f, avg accuracy test: %f'%(EPOCH, total_loss_test, total_accuracy_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the minibatch still held in batch_x / batch_es / batch_y / maxlen,\n",
    "# i.e. whatever the last test minibatch of the training cell left behind.\n",
    "predict_feed = {\n",
    "    model.X: batch_x,\n",
    "    model.source_extend_tokens: batch_es,\n",
    "    model.Y: batch_y,\n",
    "    model.source_oov_words: maxlen,\n",
    "}\n",
    "predicted = sess.run(model.predicting_ids, feed_dict=predict_feed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "row 1\n",
      "BEFORE: <assumed>\n",
      "REAL AFTER: <assume>\n",
      "PREDICTED AFTER: <assum>e>e>e>e>m>m>m>m>e>i \n",
      "\n",
      "row 2\n",
      "BEFORE: <anaesthetizing>\n",
      "REAL AFTER: <anaesthetize>\n",
      "PREDICTED AFTER: <anaesthetie>e>e>e>e> \n",
      "\n",
      "row 3\n",
      "BEFORE: <contacts>\n",
      "REAL AFTER: <contact>\n",
      "PREDICTED AFTER: <contaat>t>t>t>t>t>t>t>t> \n",
      "\n",
      "row 4\n",
      "BEFORE: <arums>\n",
      "REAL AFTER: <arum>\n",
      "PREDICTED AFTER: <arm>m>m> \n",
      "\n",
      "row 5\n",
      "BEFORE: <boxcars>\n",
      "REAL AFTER: <boxcar>\n",
      "PREDICTED AFTER: <boxcaa>e>e>e>i>ar>r>r>r> \n",
      "\n",
      "row 6\n",
      "BEFORE: <batters>\n",
      "REAL AFTER: <batter>\n",
      "PREDICTED AFTER: <batteeeeeeeer>r>r> \n",
      "\n",
      "row 7\n",
      "BEFORE: <compositions>\n",
      "REAL AFTER: <composition>\n",
      "PREDICTED AFTER: <compositiiie>e>e>t \n",
      "\n",
      "row 8\n",
      "BEFORE: <acclimatizing>\n",
      "REAL AFTER: <acclimatize>\n",
      "PREDICTED AFTER: <acclimatie>e>e>e>e>e> \n",
      "\n",
      "row 9\n",
      "BEFORE: <broadcasted>\n",
      "REAL AFTER: <broadcast>\n",
      "PREDICTED AFTER: <broadcastee>e>e>e>e> \n",
      "\n",
      "row 10\n",
      "BEFORE: <building>\n",
      "REAL AFTER: <build>\n",
      "PREDICTED AFTER: <b<biild>e>e>e>e>iiiie> \n",
      "\n",
      "row 11\n",
      "BEFORE: <botanies>\n",
      "REAL AFTER: <botany>\n",
      "PREDICTED AFTER: <botany>e>e>e>e>e>e>e>e>e> \n",
      "\n",
      "row 12\n",
      "BEFORE: <bootstrapped>\n",
      "REAL AFTER: <bootstrap>\n",
      "PREDICTED AFTER: <bootstrap>p>p>e>e>e>e> \n",
      "\n",
      "row 13\n",
      "BEFORE: <adding-machines>\n",
      "REAL AFTER: <adding-machine>\n",
      "PREDICTED AFTER: <adding-mahhiie>e> \n",
      "\n",
      "row 14\n",
      "BEFORE: <befitting>\n",
      "REAL AFTER: <befit>\n",
      "PREDICTED AFTER: <befit>tt>e>e>e>e>tiii \n",
      "\n",
      "row 15\n",
      "BEFORE: <audiocassettes>\n",
      "REAL AFTER: <audiocassette>\n",
      "PREDICTED AFTER: <audiocassettttt \n",
      "\n",
      "row 16\n",
      "BEFORE: <bunions>\n",
      "REAL AFTER: <bunion>\n",
      "PREDICTED AFTER: <bunion>e>oon>e>n>iio \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Ids 0-3 are dropped before printing (presumably reserved special tokens;\n",
    "# confirm against the cell that builds the dictionaries).\n",
    "skipped_ids = (0, 1, 2, 3)\n",
    "\n",
    "def ids_to_text(ids):\n",
    "    # Look up each remaining id in rev_dictionary_from and join the pieces.\n",
    "    return ''.join(rev_dictionary_from[n] for n in ids if n not in skipped_ids)\n",
    "\n",
    "# One report per row: source word, reference output, model output.\n",
    "for row, source_ids in enumerate(batch_x):\n",
    "    print('row %d'%(row+1))\n",
    "    print('BEFORE:', ids_to_text(source_ids))\n",
    "    print('REAL AFTER:', ids_to_text(batch_y[row]))\n",
    "    print('PREDICTED AFTER:', ids_to_text(predicted[row]), '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
